import copy

import networkx as nx
import matplotlib.pyplot as plt

from ModularUtils.ControllerConstants import generate_permutations
from ModularUtils.FunctionsConstant import build_compares


def getdoKey(obs_Var, intv_key):
    """Build the string key naming an interventional query.

    The result has the form ``"P(<obs>|do<intv>)"``, where ``<obs>`` is the
    concatenation of the observed variable names. ``<intv>`` depends on the
    type of ``intv_key``:

    * a ``dict`` (including subclasses such as ``OrderedDict``): the
      concatenated variable names, ``"_"``, then the concatenated values,
      e.g. ``{"X1": 0}`` -> ``"X1_0"``;
    * any other iterable of names: the concatenated names,
      e.g. ``["X1"]`` -> ``"X1"``.

    An empty ``intv_key`` additionally appends ``"[]"``.

    No parentheses are put around the intervened part, so the result reads
    ``"P(X2W|doX1)"`` rather than ``"P(X2W|do(X1))"``. The format is kept
    unchanged because callers may rely on it (e.g. as a lookup key).
    NOTE(review): verify before changing.

    Args:
        obs_Var: iterable of observed variable names (strings).
        intv_key: dict mapping intervened variable name -> value, or an
            iterable of intervened variable names (strings).

    Returns:
        The query string.
    """
    observed = "".join(obs_Var)

    if isinstance(intv_key, dict):
        # isinstance (not ``type(...) == dict``) so dict subclasses such as
        # OrderedDict keep their values in the key instead of dropping them.
        intervened = "".join(intv_key.keys()) + "_" + "".join(str(v) for v in intv_key.values())
    else:
        intervened = "".join(intv_key)

    # Mark an empty intervention set explicitly.
    suffix = "[]" if len(intv_key) == 0 else ""
    return "P(" + observed + "|do" + intervened + suffix + ")"


import networkx as nx
import matplotlib.pyplot as plt







def set_nonid_mnist_images_minimized(noise_states, latent_state, obs_state, Data_intervs):
    """Return the configuration of the minimized "nonid_mnist_images" setup.

    Observed graph: X1 -> X2, X1 -> W, X2 -> W, and W -> ImgYdigit1 /
    ImgYdigit2. Latent U1 confounds X1 and X2; latent U2 confounds X2 and
    ImgYdigit1. The only interventional query is P(X2, W | do(X1)) over every
    value of X1. No counterfactual setup is provided: all ``cf_*`` values,
    ``cflabel_names``, ``twin_map`` and ``Twin_Network`` are ``None``.

    Args:
        noise_states: size of each exogenous noise variable ``n<label>``
            (used in ``noise_params`` and ``label_dim``).
        latent_state: size of the latent confounders U1 and U2.
        obs_state: unused.
        Data_intervs: iterable of dicts, one per available dataset. Each
            dict's keys are passed to ``build_compares``, and the dict itself
            is stored as ``"intv"``. Presumably the keys are the intervened
            variables of that dataset, with ``{}`` meaning observational
            (TODO confirm against caller).

    Returns:
        A 22-tuple, in this order: DAG_desc, Complete_DAG_desc, Complete_DAG,
        complete_labels, Observed_DAG, label_names, image_labels,
        interv_queries, cf_queries, latent_conf, confTochild, exogenous,
        cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map,
        Twin_Network, cf_exogenous, noise_params, train_mech_dict, label_dim.
    """
    DAG_desc = "nonid_mnist_images"
    Complete_DAG_desc = "nonid_mnist_images"

    # Full causal graph, including the latent confounders. Image nodes are
    # not part of it in this minimized variant.
    Complete_DAG = {
        "U1": [],
        "U2": [],
        "X1": ["U1"],
        "X2": ["U1", "U2", "X1"],
        "W": ["X1", "X2"],
    }
    complete_labels = list(Complete_DAG.keys())

    # Graph over the observed variables only (latents removed).
    Observed_DAG = {
        "X1": [],
        "X2": ["X1"],
        "W": ["X1", "X2"],
        "ImgYdigit1": ["W"],
        "ImgYdigit2": ["W"],
    }
    label_names = list(Observed_DAG.keys())

    image_labels = ["ImgYdigit1", "ImgYdigit2"]

    label_dim = {
        "U1": {"feature": latent_state},
        "U2": {"feature": latent_state},
        "X1": {"feature": 2},  # [0,1]
        "X2": {"feature": 9},  # [0,8]
        "X1p": {"feature": 2},  # [0,1]
        "X2p": {"feature": 9},  # [0,8]
        "W": {"feature": 10},  # [0,9]
    }

    # Drop the two image labels and the intervened root X1 -> ["X2", "W"].
    var_list = label_names[0:-2][1:]
    intervention_list = [{"obs": var_list, "inter_vars": ["X1"]}]
    for intervention in intervention_list:
        intervention["expr"] = getdoKey(intervention["obs"], intervention["inter_vars"])

    interv_queries = []
    for intervention in intervention_list:
        perms = generate_permutations([label_dim[lb]["feature"] for lb in intervention["inter_vars"]])
        key_val = [dict(zip(intervention["inter_vars"], comb)) for comb in perms]
        interv_queries.append({"obs": intervention["obs"], "intervs": key_val, "expr": intervention["expr"]})
        # e.g. {'obs': ['X2', 'W'], 'intervs': [{'X1': 0}, {'X1': 1}], 'expr': 'P(X2W|doX1)'}
        # (assuming generate_permutations enumerates 0..n-1 per variable)

    # Which latents confound each observed label, and the inverse mapping.
    latent_conf = {"X1": ["U1"], "X2": ["U1", "U2"], "W": [], "ImgYdigit1": ["U2"], "ImgYdigit2": []}
    confTochild = {"U1": ["X1", "X2"], "U2": ["X2", "ImgYdigit1"]}

    # Every non-image label gets its own exogenous noise variable "n<label>".
    exogenous = {label: "n" + label for label in label_names if label not in image_labels}

    noise_params = {"nX1": (0.1, noise_states),
                    "nX2": (0.1, noise_states),
                    "nW": (0.1, noise_states),
                    "U1": (1, latent_state),
                    "U2": (1, latent_state)}

    # Mechanism training: one entry per (label, dataset). Every label is
    # pre-created so the printout below cannot raise KeyError when
    # Data_intervs is empty.
    train_mech_dict = {label: [] for label in label_names}
    for dist in Data_intervs:
        comp_dict = build_compares(confTochild, Observed_DAG, label_names, list(dist.keys()))
        for label in label_names:
            if label in image_labels:
                continue  # image mechanisms are set explicitly below
            train_mech_dict[label].append(
                {"parents": Observed_DAG[label], "intv": dist, "compare": comp_dict[label]})

    # Image labels are configured by hand. Unlike the other labels, the
    # mechanism itself is not included in "compare".
    train_mech_dict["ImgYdigit1"] = [{'parents': ['W'], 'intv': {}, 'compare': ['X1', 'X2', 'W']}]
    train_mech_dict["ImgYdigit2"] = [
        {'parents': ['W'], 'intv': {}, 'compare': ['W']}]  # ideally ImgYdigit2 would be compared too

    print("printing")
    for label in label_names:
        print(label, train_mech_dict[label])

    for label in Observed_DAG:
        if label not in image_labels:
            label_dim["n" + label] = {"feature": noise_states}

    # This variant defines no counterfactual setup.
    cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous = None, None, None, None, None, None, None
    cf_queries = None
    return DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels, Observed_DAG, label_names, image_labels, interv_queries, cf_queries, latent_conf, \
           confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, label_dim





def set_nonid_mnist_images(noise_states, latent_state, obs_state, Data_intervs):
    """Return the full configuration of the "nonid_mnist_images" setup.

    Observed graph: X1 -> X2, X1 -> W, X2 -> W. W is the parent of Ydigit1,
    Ydigit2, Ycolor and Ythick, and these feed the two images:
    ImgYdigit1 <- (Ydigit1, Ycolor, Ythick) and
    ImgYdigit2 <- (Ydigit2, Ycolor, Ythick). Latent U1 confounds X1 and X2;
    latent U2 confounds X2 and Ycolor.

    The interventional query is P(X2, W | do(X1)) over every value of X1.
    A twin network is also returned for the counterfactual
    P(Ycolor | do(X1=1, X2=5), X1p=1, X2p=1). ``cf_queries`` itself is
    returned empty.

    Args:
        noise_states: size of each exogenous noise variable ``n<label>``
            (used in ``noise_params`` and ``label_dim``).
        latent_state: size of the latent confounders U1 and U2.
        obs_state: unused.
        Data_intervs: iterable of dicts, one per available dataset. Each
            dict's keys are passed to ``build_compares``, and the dict itself
            is stored as ``"intv"``. Presumably the keys are the intervened
            variables of that dataset, with ``{}`` meaning observational
            (TODO confirm against caller).

    Returns:
        A 22-tuple, in this order: DAG_desc, Complete_DAG_desc, Complete_DAG,
        complete_labels, Observed_DAG, label_names, image_labels,
        interv_queries, cf_queries, latent_conf, confTochild, exogenous,
        cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map,
        Twin_Network, cf_exogenous, noise_params, train_mech_dict, label_dim.
    """
    DAG_desc = "nonid_mnist_images"
    Complete_DAG_desc = "nonid_mnist_images"

    # Full causal graph, including the latent confounders (image nodes excluded).
    Complete_DAG = {
        "U1": [],
        "U2": [],
        "X1": ["U1"],
        "X2": ["U1", "U2", "X1"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["U2", "W"],
        "Ythick": ["W"],
    }
    complete_labels = list(Complete_DAG.keys())

    # Graph over the observed variables only (latents removed).
    Observed_DAG = {
        "X1": [],
        "X2": ["X1"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["W"],
        "Ythick": ["W"],
        "ImgYdigit1": ["Ydigit1", "Ycolor", "Ythick"],
        "ImgYdigit2": ["Ydigit2", "Ycolor", "Ythick"],
    }
    label_names = list(Observed_DAG.keys())

    image_labels = ["ImgYdigit1", "ImgYdigit2"]

    label_dim = {
        "U1": {"feature": latent_state},
        "U2": {"feature": latent_state},
        "X1": {"feature": 2},  # [0,1]
        "X2": {"feature": 9},  # [0,8]
        "X1p": {"feature": 2},  # [0,1]
        "X2p": {"feature": 9},  # [0,8]
        "W": {"feature": 10},  # [0,9]
        "Ydigit1": {"feature": 10},  # [0,9]
        "Ydigit2": {"feature": 10},  # [0,9]
        "Ycolor": {"feature": 3},  # [0,2]
        "Ythick": {"feature": 2}  # [0,1]
    }

    # Build "expr" from the entry itself so the observed/intervened lists
    # cannot drift apart.
    intervention_list = [{"obs": ["X2", "W"], "inter_vars": ["X1"]}]
    for intervention in intervention_list:
        intervention["expr"] = getdoKey(intervention["obs"], intervention["inter_vars"])

    interv_queries = []
    for intervention in intervention_list:
        perms = generate_permutations([label_dim[lb]["feature"] for lb in intervention["inter_vars"]])
        key_val = [dict(zip(intervention["inter_vars"], comb)) for comb in perms]
        interv_queries.append({"obs": intervention["obs"], "intervs": key_val, "expr": intervention["expr"]})
        # e.g. {'obs': ['X2', 'W'], 'intervs': [{'X1': 0}, {'X1': 1}], 'expr': 'P(X2W|doX1)'}
        # (assuming generate_permutations enumerates 0..n-1 per variable)

    # No counterfactual queries are generated here; the counterfactual setup
    # is returned through cf_intervene / cf_observe / cf_evidence below.
    cf_queries = []

    # Which latents confound each observed label, and the inverse mapping.
    # NOTE: the image labels are deliberately treated as unconfounded here.
    latent_conf = {"X1": ["U1"], "X2": ["U1", "U2"], "W": [], "Ydigit1": [], "Ydigit2": [], "Ycolor": ["U2"],
                   "Ythick": [], "ImgYdigit1": [], "ImgYdigit2": []}
    confTochild = {"U1": ["X1", "X2"], "U2": ["X2", "Ycolor"]}

    # Every non-image label gets its own exogenous noise variable "n<label>".
    exogenous = {label: "n" + label for label in label_names if label not in image_labels}

    # Twin network for the counterfactual: X1 and X2 are intervened (no
    # parents), and the "p" copies carry the factual evidence.
    cflabel_names = ["U1", "U2", "X1", "X1p", "X2", "X2p", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    Twin_Network = {
        "U1": [],
        "U2": [],
        "X1": [],
        "X2": [],
        "X1p": [],
        "X2p": ["U1", "U2", "X1p"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["U2", "W"],
        "Ythick": ["W"],
    }
    cf_exogenous = {"X2p": "nX2", "W": "nW", "Ydigit1": "nYdigit1", "Ydigit2": "nYdigit2", "Ycolor": "nYcolor",
                    "Ythick": "nYthick"}

    cf_intervene = {"X1": 1, "X2": 5}
    cf_observe = ["Ycolor"]
    cf_evidence = {"X1p": 1, "X2p": 1}

    # Maps each intervened variable to its twin copy, and back.
    twin_map = {"X1p": "X1", "X1": "X1p", "X2p": "X2", "X2": "X2p"}

    noise_params = {"nX1": (0.1, noise_states),
                    "nX2": (0.1, noise_states),
                    "nW": (0.1, noise_states),
                    "nYdigit1": (0.1, noise_states),
                    "nYdigit2": (0.1, noise_states),
                    "nYcolor": (0.1, noise_states),
                    "nYthick": (0.1, noise_states),
                    "U1": (1, latent_state),
                    "U2": (1, latent_state)}

    # Mechanism training: one entry per (label, dataset). Every label is
    # pre-created so the printout below cannot raise KeyError when
    # Data_intervs is empty. Image labels are left with an empty list.
    train_mech_dict = {label: [] for label in label_names}
    for dist in Data_intervs:
        comp_dict = build_compares(confTochild, Observed_DAG, label_names, list(dist.keys()))
        for label in label_names:
            if label in image_labels:
                continue
            train_mech_dict[label].append(
                {"parents": Observed_DAG[label], "intv": dist, "compare": comp_dict[label]})

    print("printing")
    for label in label_names:
        print(label, train_mech_dict[label])

    for label in Observed_DAG:
        if label not in image_labels:
            label_dim["n" + label] = {"feature": noise_states}

    return DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels, Observed_DAG, label_names, image_labels, interv_queries, cf_queries, latent_conf, \
           confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, label_dim



